{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "This is a companion notebook for the book [Deep Learning with Python, Third Edition](https://www.manning.com/books/deep-learning-with-python-third-edition). For readability, it only contains runnable code blocks and section titles, and omits everything else in the book: text paragraphs, figures, and pseudocode.\n\n**If you want to be able to follow what's going on, I recommend reading the notebook side by side with your copy of the book.**\n\nThe book's contents are available online at [deeplearningwithpython.io](https://deeplearningwithpython.io)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Install or upgrade Keras and KerasHub; -q keeps pip's output quiet.\n",
    "!pip install keras keras-hub --upgrade -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "# Select the backend through the KERAS_BACKEND environment variable; the\n",
    "# `backend` cell magic defined below compares against this value.\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# @title\n",
    "import os\n",
    "from IPython.core.magic import register_cell_magic\n",
    "\n",
    "@register_cell_magic\n",
    "def backend(line, cell):\n",
    "    \"\"\"Run `cell` only if KERAS_BACKEND equals the backend named in `line`.\n",
    "\n",
    "    The required backend is the last whitespace-separated word of `line`.\n",
    "    On a mismatch the cell is skipped and a hint is printed instead.\n",
    "    \"\"\"\n",
    "    words = line.split()\n",
    "    if not words:\n",
    "        # Without this guard a bare `%%backend` raised IndexError on `[-1]`.\n",
    "        print(\"Usage: %%backend <name> (for example: %%backend jax)\")\n",
    "        return\n",
    "    current, required = os.environ.get(\"KERAS_BACKEND\", \"\"), words[-1]\n",
    "    if current == required:\n",
    "        get_ipython().run_cell(cell)\n",
    "    else:\n",
    "        print(\n",
    "            f\"This cell requires the {required} backend. To run it, change KERAS_BACKEND to \"\n",
    "            f\"\\\"{required}\\\" at the top of the notebook, restart the runtime, and rerun the notebook.\"\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Image segmentation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Computer vision tasks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Types of image segmentation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Training a segmentation model from scratch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Downloading a segmentation dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Download and unpack the image and annotation archives; later cells read\n",
    "# from the `images` and `annotations/trimaps` directories.\n",
    "!wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz\n",
    "!wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz\n",
    "!tar -xf images.tar.gz\n",
    "!tar -xf annotations.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import pathlib\n",
    "\n",
    "input_dir = pathlib.Path(\"images\")\n",
    "target_dir = pathlib.Path(\"annotations/trimaps\")\n",
    "\n",
    "# Both listings are sorted because later cells pair them by index\n",
    "# (assumes image and mask file names sort in the same order - TODO confirm).\n",
    "input_img_paths = sorted(input_dir.glob(\"*.jpg\"))\n",
    "# The \"[!.]\" prefix skips PNG files whose names start with a dot.\n",
    "target_paths = sorted(target_dir.glob(\"[!.]*.png\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from keras.utils import load_img, img_to_array, array_to_img\n",
    "\n",
    "# Display the input image at index 9 with the axes hidden.\n",
    "plt.axis(\"off\")\n",
    "plt.imshow(load_img(input_img_paths[9]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def display_target(target_array):\n",
    "    \"\"\"Plot channel 0 of `target_array` after mapping each value v to (v - 1) * 127.\"\"\"\n",
    "    zero_based = target_array.astype(\"uint8\") - 1\n",
    "    stretched = zero_based * 127\n",
    "    plt.axis(\"off\")\n",
    "    plt.imshow(stretched[:, :, 0])\n",
    "\n",
    "# Load the mask at index 9 as a single-channel array and display it.\n",
    "target_img = load_img(target_paths[9], color_mode=\"grayscale\")\n",
    "img = img_to_array(target_img)\n",
    "display_target(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "\n",
    "img_size = (200, 200)\n",
    "num_imgs = len(input_img_paths)\n",
    "\n",
    "# Two identically seeded generators permute both lists the same way\n",
    "# (assuming equal lengths), so index i still pairs an image with its mask.\n",
    "random.Random(1337).shuffle(input_img_paths)\n",
    "random.Random(1337).shuffle(target_paths)\n",
    "\n",
    "def path_to_input_image(path):\n",
    "    \"\"\"Load the image at `path`, resized to `img_size`, as an array.\"\"\"\n",
    "    pil_img = load_img(path, target_size=img_size)\n",
    "    return img_to_array(pil_img)\n",
    "\n",
    "def path_to_target(path):\n",
    "    \"\"\"Load the mask at `path` as a grayscale uint8 array with every value reduced by 1.\"\"\"\n",
    "    pil_mask = load_img(path, target_size=img_size, color_mode=\"grayscale\")\n",
    "    # Presumably this makes the labels start at 0 - TODO confirm.\n",
    "    return img_to_array(pil_mask).astype(\"uint8\") - 1\n",
    "\n",
    "# Preallocate the full dataset, then fill it one sample at a time.\n",
    "input_imgs = np.zeros((num_imgs, *img_size, 3), dtype=\"float32\")\n",
    "targets = np.zeros((num_imgs, *img_size, 1), dtype=\"uint8\")\n",
    "for i, input_path in enumerate(input_img_paths):\n",
    "    input_imgs[i] = path_to_input_image(input_path)\n",
    "    targets[i] = path_to_target(target_paths[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Hold out the last `num_val_samples` samples for validation; train on the rest.\n",
    "num_val_samples = 1000\n",
    "train_input_imgs = input_imgs[:-num_val_samples]\n",
    "train_targets = targets[:-num_val_samples]\n",
    "val_input_imgs = input_imgs[-num_val_samples:]\n",
    "val_targets = targets[-num_val_samples:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Building and training the segmentation model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import keras\n",
    "from keras.layers import Rescaling, Conv2D, Conv2DTranspose\n",
    "\n",
    "def get_model(img_size, num_classes):\n",
    "    \"\"\"Build a convolutional model with a `num_classes`-channel softmax output.\n",
    "\n",
    "    Args:\n",
    "        img_size: 2-tuple of spatial dimensions; a 3-channel axis is appended.\n",
    "        num_classes: Number of filters in the final softmax convolution.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled `keras.Model`.\n",
    "    \"\"\"\n",
    "    conv_args = {\"activation\": \"relu\", \"padding\": \"same\"}\n",
    "    inputs = keras.Input(shape=img_size + (3,))\n",
    "    x = Rescaling(1.0 / 255)(inputs)\n",
    "\n",
    "    # Three stages, each a stride-2 convolution followed by a stride-1 one.\n",
    "    for filters in (64, 128, 256):\n",
    "        x = Conv2D(filters, 3, strides=2, **conv_args)(x)\n",
    "        x = Conv2D(filters, 3, **conv_args)(x)\n",
    "\n",
    "    # Mirror image: a stride-1 transposed convolution, then a stride-2 one.\n",
    "    for filters in (256, 128, 64):\n",
    "        x = Conv2DTranspose(filters, 3, **conv_args)(x)\n",
    "        x = Conv2DTranspose(filters, 3, strides=2, **conv_args)(x)\n",
    "\n",
    "    outputs = Conv2D(num_classes, 3, activation=\"softmax\", padding=\"same\")(x)\n",
    "    return keras.Model(inputs, outputs)\n",
    "\n",
    "model = get_model(img_size=img_size, num_classes=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# \u26a0\ufe0fNOTE\u26a0\ufe0f: The following IoU metric is *very* slow on the PyTorch backend!\n",
    "# If you are running with PyTorch, we recommend re-running the notebook with JAX\n",
    "# or TensorFlow, or skipping to the next section of this chapter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# IoU restricted to class 0 (`target_class_ids`), reported as \"foreground_iou\".\n",
    "# Targets are integer labels (`sparse_y_true=True`); predictions are per-class\n",
    "# scores rather than integer labels (`sparse_y_pred=False`).\n",
    "foreground_iou = keras.metrics.IoU(\n",
    "    num_classes=3,\n",
    "    target_class_ids=(0,),\n",
    "    name=\"foreground_iou\",\n",
    "    sparse_y_true=True,\n",
    "    sparse_y_pred=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Compile with the IoU metric defined above.\n",
    "model.compile(\n",
    "    optimizer=\"adam\",\n",
    "    loss=\"sparse_categorical_crossentropy\",\n",
    "    metrics=[foreground_iou],\n",
    ")\n",
    "# Checkpoint to \"oxford_segmentation.keras\", keeping only the best model\n",
    "# (`save_best_only=True`); a later cell reloads this file.\n",
    "callbacks = [\n",
    "    keras.callbacks.ModelCheckpoint(\n",
    "        \"oxford_segmentation.keras\",\n",
    "        save_best_only=True,\n",
    "    ),\n",
    "]\n",
    "# Train for 50 epochs, evaluating on the held-out validation arrays.\n",
    "history = model.fit(\n",
    "    train_input_imgs,\n",
    "    train_targets,\n",
    "    epochs=50,\n",
    "    callbacks=callbacks,\n",
    "    batch_size=64,\n",
    "    validation_data=(val_input_imgs, val_targets),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Plot the training and validation loss recorded in `history`, per epoch.\n",
    "loss = history.history[\"loss\"]\n",
    "val_loss = history.history[\"val_loss\"]\n",
    "epochs = range(1, len(loss) + 1)\n",
    "plt.figure()\n",
    "for series, fmt, label in (\n",
    "    (loss, \"r--\", \"Training loss\"),\n",
    "    (val_loss, \"b\", \"Validation loss\"),\n",
    "):\n",
    "    plt.plot(epochs, series, fmt, label=label)\n",
    "plt.title(\"Training and validation loss\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Reload the checkpoint written during training.\n",
    "model = keras.models.load_model(\"oxford_segmentation.keras\")\n",
    "\n",
    "i = 4\n",
    "test_image = val_input_imgs[i]\n",
    "plt.axis(\"off\")\n",
    "plt.imshow(array_to_img(test_image))\n",
    "# Flush this figure now; otherwise the mask drawn below lands on the same\n",
    "# axes and completely covers the test image.\n",
    "plt.show()\n",
    "\n",
    "# Predict on a batch of one image and keep the single result.\n",
    "mask = model.predict(np.expand_dims(test_image, 0))[0]\n",
    "\n",
    "def display_mask(pred):\n",
    "    \"\"\"Show the argmax of `pred` over its last axis, scaled by 127.\"\"\"\n",
    "    mask = np.argmax(pred, axis=-1)\n",
    "    mask *= 127\n",
    "    plt.axis(\"off\")\n",
    "    plt.imshow(mask)\n",
    "\n",
    "display_mask(mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Using a pretrained segmentation model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Downloading the Segment Anything Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import keras_hub\n",
    "\n",
    "# Load the \"sam_huge_sa1b\" preset. This rebinds `model`, which until now\n",
    "# referred to the segmentation model trained earlier in the notebook.\n",
    "model = keras_hub.models.ImageSegmenter.from_preset(\"sam_huge_sa1b\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Display the parameter count of the loaded model.\n",
    "model.count_params()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### How Segment Anything works"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Preparing a test image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Fetch the sample image and load it as an array.\n",
    "path = keras.utils.get_file(\n",
    "    origin=\"https://s3.amazonaws.com/keras.io/img/book/fruits.jpg\"\n",
    ")\n",
    "pil_image = keras.utils.load_img(path)\n",
    "image_array = keras.utils.img_to_array(pil_image)\n",
    "\n",
    "# Cast to uint8 for display.\n",
    "plt.imshow(image_array.astype(\"uint8\"))\n",
    "plt.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from keras import ops\n",
    "\n",
    "# Target size used by `resize_and_pad` below.\n",
    "image_size = (1024, 1024)\n",
    "\n",
    "def resize_and_pad(x):\n",
    "    \"\"\"Resize `x` to `image_size` with `pad_to_aspect_ratio=True`.\"\"\"\n",
    "    return ops.image.resize(x, image_size, pad_to_aspect_ratio=True)\n",
    "\n",
    "image = resize_and_pad(image_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from keras import ops\n",
    "\n",
    "def show_image(image, ax):\n",
    "    \"\"\"Draw `image` on `ax` after converting it to a uint8 NumPy array.\"\"\"\n",
    "    pixels = ops.convert_to_numpy(image)\n",
    "    ax.imshow(pixels.astype(\"uint8\"))\n",
    "\n",
    "def show_mask(mask, ax):\n",
    "    \"\"\"Overlay a 3D `mask` on `ax`, tinted with a fixed RGBA color.\"\"\"\n",
    "    rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])\n",
    "    height, width, _ = mask.shape\n",
    "    # Broadcasting (h, w, 1) against (1, 1, 4) scales `rgba` by each mask entry.\n",
    "    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)\n",
    "    ax.imshow(overlay)\n",
    "\n",
    "def show_points(points, ax):\n",
    "    \"\"\"Scatter `points` on `ax`, reading x from column 0 and y from column 1.\"\"\"\n",
    "    ax.scatter(\n",
    "        points[:, 0],\n",
    "        points[:, 1],\n",
    "        c=\"green\",\n",
    "        marker=\"*\",\n",
    "        s=375,\n",
    "        ec=\"white\",\n",
    "        lw=1.25,\n",
    "    )\n",
    "\n",
    "def show_box(box, ax):\n",
    "    \"\"\"Outline `box` on `ax`; once flattened it is read as (x0, y0, x1, y1).\"\"\"\n",
    "    corners = box.reshape(-1)\n",
    "    left, top = corners[0], corners[1]\n",
    "    width = corners[2] - corners[0]\n",
    "    height = corners[3] - corners[1]\n",
    "    outline = plt.Rectangle((left, top), width, height, ec=\"red\", fc=\"none\", lw=2)\n",
    "    ax.add_patch(outline)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Prompting the model with a target point"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# One prompt point; `show_points` reads column 0 as x and column 1 as y.\n",
    "input_point = np.array([[580, 450]])\n",
    "# One label per point (presumably 1 marks the point as part of the target - TODO confirm).\n",
    "input_label = np.array([1])\n",
    "\n",
    "plt.figure(figsize=(10, 10))\n",
    "show_image(image, plt.gca())\n",
    "show_points(input_point, plt.gca())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Each input gets a leading axis of size 1 before being passed to the model.\n",
    "outputs = model.predict(\n",
    "    {\n",
    "        \"images\": ops.expand_dims(image, axis=0),\n",
    "        \"points\": ops.expand_dims(input_point, axis=0),\n",
    "        \"labels\": ops.expand_dims(input_label, axis=0),\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Inspect the shape of the returned masks.\n",
    "outputs[\"masks\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def get_mask(sam_outputs, index=0):\n",
    "    \"\"\"Return mask `index` from the first entry of `sam_outputs[\"masks\"]`.\n",
    "\n",
    "    The mask is passed through `resize_and_pad` and thresholded at 0, so the\n",
    "    result is a boolean NumPy array.\n",
    "    \"\"\"\n",
    "    raw = sam_outputs[\"masks\"][0][index]\n",
    "    # Append a trailing axis before resizing.\n",
    "    with_channel = np.expand_dims(raw, axis=-1)\n",
    "    resized = ops.convert_to_numpy(resize_and_pad(with_channel))\n",
    "    return resized > 0.0\n",
    "\n",
    "mask = get_mask(outputs, index=0)\n",
    "\n",
    "# Draw the image, the selected mask and the prompt point on one set of axes.\n",
    "plt.figure(figsize=(10, 10))\n",
    "ax = plt.gca()\n",
    "show_image(image, ax)\n",
    "show_mask(mask, ax)\n",
    "show_points(input_point, ax)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Prompt the model with a different point, again with label 1.\n",
    "input_point = np.array([[300, 550]])\n",
    "input_label = np.array([1])\n",
    "\n",
    "outputs = model.predict(\n",
    "    {\n",
    "        \"images\": ops.expand_dims(image, axis=0),\n",
    "        \"points\": ops.expand_dims(input_point, axis=0),\n",
    "        \"labels\": ops.expand_dims(input_label, axis=0),\n",
    "    }\n",
    ")\n",
    "mask = get_mask(outputs, index=0)\n",
    "\n",
    "# Overlay the mask at index 0 and the prompt point on the image.\n",
    "plt.figure(figsize=(10, 10))\n",
    "show_image(image, plt.gca())\n",
    "show_mask(mask, plt.gca())\n",
    "show_points(input_point, plt.gca())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Show every mask after the first one, side by side.\n",
    "fig, axes = plt.subplots(1, 3, figsize=(20, 60))\n",
    "masks = outputs[\"masks\"][0][1:]\n",
    "for i in range(len(masks)):\n",
    "    ax = axes[i]\n",
    "    show_image(image, ax)\n",
    "    show_points(input_point, ax)\n",
    "    # `masks` skips index 0, so panel i shows mask i + 1.\n",
    "    mask = get_mask(outputs, index=i + 1)\n",
    "    show_mask(mask, ax)\n",
    "    ax.set_title(f\"Mask {i + 1}\", fontsize=16)\n",
    "    ax.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Prompting the model with a target box"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Box prompt given as two corner points; `show_box` flattens it and reads\n",
    "# the values as (x0, y0, x1, y1).\n",
    "input_box = np.array(\n",
    "    [\n",
    "        [520, 180],\n",
    "        [770, 420],\n",
    "    ]\n",
    ")\n",
    "\n",
    "plt.figure(figsize=(10, 10))\n",
    "show_image(image, plt.gca())\n",
    "show_box(input_box, plt.gca())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "outputs = model.predict(\n",
    "    {\n",
    "        \"images\": ops.expand_dims(image, axis=0),\n",
    "        # axis=(0, 1) prepends two axes of size 1 to the (2, 2) box array.\n",
    "        \"boxes\": ops.expand_dims(input_box, axis=(0, 1)),\n",
    "    }\n",
    ")\n",
    "mask = get_mask(outputs, 0)\n",
    "# Overlay the mask at index 0 and the prompt box on the image.\n",
    "plt.figure(figsize=(10, 10))\n",
    "show_image(image, plt.gca())\n",
    "show_mask(mask, plt.gca())\n",
    "show_box(input_box, plt.gca())\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "chapter11_image-segmentation",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}